# analysis_delta.py
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from utils import gaussian_gram_matrix, compute_von_neumann_entropy_from_density, compute_joint_entropy

def _gather_token_states(layer_states, context_lengths, offset=0):
    """Stack, per sample, the hidden vector at position ``context_length + offset``.

    ``layer_states`` is indexed as ``layer_states[sample, position, :]``; the
    result has one row per entry of ``context_lengths``.
    """
    return torch.stack(
        [layer_states[idx, cl + offset, :] for idx, cl in enumerate(context_lengths)]
    )


def _mutual_information_with_target(z, gram_y, H_Y, sigma):
    """Return ``H(Z) + H(Y) - H(Z, Y)`` for the per-sample representation ``z``.

    ``z`` gets a singleton middle axis and is L2-normalised along its last
    axis before the Gaussian Gram matrix is built.  The trace-normalised Gram
    matrix feeds the entropy helper, while the un-normalised Gram matrices are
    handed to ``compute_joint_entropy`` exactly as the original code did
    (presumably that helper normalises internally -- confirm in ``utils``).
    """
    z = z.unsqueeze(1)
    z_norm = F.normalize(z, p=2, dim=-1)
    gram_z = gaussian_gram_matrix(z_norm, sigma)
    gram_z_norm = gram_z / torch.trace(gram_z)
    H_Z = compute_von_neumann_entropy_from_density(gram_z_norm)
    H_joint_ZY = compute_joint_entropy(gram_z, gram_y)
    return H_Z + H_Y - H_joint_ZY


def _cell_label(i, j, symbol):
    """Human-readable description of what cell ``(i, j)`` of the MI matrix holds."""
    if i == j:
        return f"I({symbol}_{i}; Y)"
    return f"I({symbol}_{j}-{symbol}_{i}; Y)"


def _save_full_heatmap(mi_matrix, tick_labels, path, vmin, vmax):
    """Save a heatmap of the whole MI matrix, each cell annotated with value and label."""
    num_layers = mi_matrix.shape[0]
    annotations = np.empty_like(mi_matrix, dtype=object)
    for i in range(num_layers):
        for j in range(num_layers):
            annotations[i, j] = f"{mi_matrix[i, j]:.3f}\n{_cell_label(i, j, 'z')}"

    plt.figure(figsize=(28, 21))
    # vmin/vmax are forwarded so several calls can share one colour scale.
    sns.heatmap(mi_matrix, annot=annotations, fmt="", cmap="viridis",
                xticklabels=tick_labels, yticklabels=tick_labels,
                vmin=vmin, vmax=vmax, annot_kws={"fontsize": 8})
    plt.title("plot")
    plt.xlabel("Layer j")
    plt.ylabel("Layer i")
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def _save_upper_triangle_heatmap(mi_matrix, tick_labels, path, epoch_name, vmin, vmax):
    """Save a heatmap restricted to the strict upper triangle (j > i)."""
    num_layers = mi_matrix.shape[0]
    upper_triangle = np.triu(mi_matrix, k=1)
    # Hide the diagonal and everything below it.
    mask = np.tril(np.ones_like(mi_matrix), k=0)

    annotations = np.empty_like(mi_matrix, dtype=object)
    for i in range(num_layers):
        for j in range(i + 1, num_layers):
            annotations[i, j] = f"{mi_matrix[i, j]:.3f}"

    plt.figure(figsize=(28, 21))
    # The original passed fontsize 14 here and then reset every annotation to
    # 8; the effective size (8) is kept and stated once.
    sns.heatmap(upper_triangle, annot=annotations, fmt="", cmap="viridis", mask=mask,
                xticklabels=tick_labels, yticklabels=tick_labels,
                vmin=vmin, vmax=vmax, annot_kws={"fontsize": 8})
    plt.title(f"Information in Layer Differences (j > i) - Epoch {epoch_name}")
    plt.xlabel("Layer j")
    plt.ylabel("Layer i")
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def _save_consecutive_bar_plot(mi_matrix, path):
    """Save a bar chart of ``mi_matrix[i, i + 1]`` for every consecutive layer pair."""
    num_layers = mi_matrix.shape[0]
    consecutive_layers = [f"{i}→{i+1}" for i in range(num_layers - 1)]
    consecutive_mi = [mi_matrix[i, i + 1] for i in range(num_layers - 1)]
    positions = range(len(consecutive_layers))

    plt.figure(figsize=(12, 6))
    plt.bar(positions, consecutive_mi)
    plt.xlabel("Layer Transition")
    plt.ylabel("I(h_j-h_i; Y)")
    plt.title("plot")
    plt.xticks(positions, consecutive_layers, rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Value labels slightly above each bar.
    for pos, value in enumerate(consecutive_mi):
        plt.text(pos, value + 0.02, f"{value:.3f}", ha='center')

    plt.tight_layout()
    plt.savefig(path, dpi=300)
    plt.close()


def analyze_delta_hidden_mi(hidden_states_z, hidden_states_y, input_ids, context_lengths,
                            lm_head_weight, sigma=1.0, save_dir=None, epoch_checkpoint=None,
                            vmin=None, vmax=None):
    """
    Analyze mutual information between layer differences and the target Y.

    For every ordered pair of layers ``(i, j)`` the cell ``mi_matrix[i, j]``
    holds ``I(h_j - h_i; Y)``; diagonal cells with ``i > 0`` hold ``I(h_i; Y)``.
    ``h_k`` is read from ``hidden_states_z[k]`` at position
    ``context_length - 1`` of each sample, and ``Y`` is read from
    ``hidden_states_y[0]`` at position ``context_length`` of each sample.

    NOTE(review): cell ``(0, 0)`` is left at 0 although the description table
    labels it ``I(h_0; Y)`` -- confirm whether that cell should be computed.

    Args:
        hidden_states_z: Sequence of per-layer tensors indexed as
            ``[sample, position, feature]``; at least ``len(hidden_states_y)``
            entries are read.
        hidden_states_y: Sequence of per-layer tensors; its length defines the
            number of layers and only its first entry is used to build ``Y``.
        input_ids: Unused; kept for call-site compatibility.
        context_lengths: Iterable with one context length per sample.
        lm_head_weight: Unused; kept for call-site compatibility.
        sigma: Bandwidth forwarded to ``gaussian_gram_matrix``.
        save_dir: If truthy, directory (created on demand) that receives two
            CSV files and three figures.
        epoch_checkpoint: Optional checkpoint path; its basename tags the
            output filenames and one plot title ("current" when omitted).
        vmin: Lower colour-scale bound forwarded to both heatmaps.
        vmax: Upper colour-scale bound forwarded to both heatmaps.

    Returns:
        dict with keys ``"mi_matrix"`` (NumPy array), ``"df"`` (the same values
        as a labelled DataFrame), ``"H_Y"`` (entropy of ``Y`` as returned by
        the entropy helper) and ``"description_df"`` (cell descriptions, or
        ``None`` when ``save_dir`` is not set).
    """
    epoch_name = os.path.basename(epoch_checkpoint) if epoch_checkpoint else "current"
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    num_layers = len(hidden_states_y)

    # Target representation Y and its entropy.
    Y = _gather_token_states(hidden_states_y[0], context_lengths)  # one row per sample
    Y = Y.unsqueeze(1)
    Y_norm = F.normalize(Y, p=2, dim=-1)
    gram_y = gaussian_gram_matrix(Y_norm, sigma)
    gram_y_norm = gram_y / torch.trace(gram_y)
    H_Y = compute_von_neumann_entropy_from_density(gram_y_norm)

    # Per-layer token states are gathered once and reused across all pairs
    # instead of being re-stacked for every (i, j) combination.
    token_cache = {}

    def layer_tokens(layer_idx):
        if layer_idx not in token_cache:
            token_cache[layer_idx] = _gather_token_states(
                hidden_states_z[layer_idx], context_lengths, offset=-1
            )
        return token_cache[layer_idx]

    mi_matrix = np.zeros((num_layers, num_layers))
    for i in range(num_layers):
        for j in range(num_layers):
            if i == 0 and j == 0:
                # Kept at 0 as in the original implementation (see NOTE above).
                continue
            z = layer_tokens(i) if i == j else layer_tokens(j) - layer_tokens(i)
            # float() is intended to accept scalar tensors regardless of
            # device or autograd state, as well as plain Python numbers.
            mi_matrix[i, j] = float(_mutual_information_with_target(z, gram_y, H_Y, sigma))

    layer_names = [f"Layer_{i}" for i in range(num_layers)]
    df = pd.DataFrame(mi_matrix, index=layer_names, columns=layer_names)

    description_df = None
    if save_dir:
        # Companion table explaining what each cell of the MI matrix represents.
        description_df = pd.DataFrame(
            [[_cell_label(i, j, "h") for j in range(num_layers)] for i in range(num_layers)],
            index=layer_names,
            columns=layer_names,
            dtype=object,
        )

        # Every artifact gets its own epoch-tagged file directly inside
        # save_dir so that successive saves no longer overwrite one another.
        desc_csv_path = os.path.join(save_dir, f"delta_hidden_mi_descriptions_{epoch_name}.csv")
        description_df.to_csv(desc_csv_path)
        print("Saved delta hidden states MI descriptions.")

        csv_path = os.path.join(save_dir, f"delta_hidden_mi_{epoch_name}.csv")
        df.to_csv(csv_path)
        print("Saved delta hidden states MI matrix.")

        tick_labels = [f"Layer {i}" for i in range(num_layers)]

        heatmap_path = os.path.join(save_dir, f"delta_hidden_mi_heatmap_{epoch_name}.png")
        _save_full_heatmap(mi_matrix, tick_labels, heatmap_path, vmin, vmax)
        print("Saved delta hidden states MI heatmap.")

        triangle_path = os.path.join(save_dir, f"delta_hidden_mi_upper_triangle_{epoch_name}.png")
        _save_upper_triangle_heatmap(mi_matrix, tick_labels, triangle_path, epoch_name, vmin, vmax)
        print("Saved layer differences upper triangle heatmap.")

        consecutive_path = os.path.join(save_dir, f"delta_hidden_mi_consecutive_{epoch_name}.png")
        _save_consecutive_bar_plot(mi_matrix, consecutive_path)
        print("Saved information flow between consecutive layers.")

    return {
        "mi_matrix": mi_matrix,
        "df": df,
        "H_Y": H_Y,
        "description_df": description_df,
    }
